/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file common_utils.h
 * \brief Shared constants, integer rounding helpers and compile-time type traits in namespace Act::Gemm.
 */

#ifndef UTILS_COMMON_UTILS_H
#define UTILS_COMMON_UTILS_H

#include "integral_constant.h"
namespace Act {
namespace Gemm {
// Size limits and alignment granularities.
// NOTE(review): units (elements vs. bytes) are not stated in this header -- confirm with the users of each value.
constexpr int64_t MATRIX_INNER_DIM_LIMIT_SIZE = 65536;
constexpr int32_t MATMUL_MNK_ALIGN = 16;
constexpr int32_t MATMUL_MNK_ALIGN_INT8 = 32;
constexpr uint32_t UB_ALIGN_SIZE = 32;
constexpr int64_t UB_FLOAT_ALIGN_NUM = 8;
constexpr int64_t PER_BLOCK_SIZE = 128;

// Buffering and event bookkeeping.
constexpr int64_t DOUBLE_BUFFER_COUNT = 2;
constexpr int64_t L1_EVENT_ID_OFFSET = 2;

// Consecutive positional indices 0..5.
// NOTE(review): presumably slots for M, N, K, batch, M0 and N0 in a shape/tiling tuple -- confirm against callers.
constexpr int MNK_M = 0;
constexpr int MNK_N = 1;
constexpr int MNK_K = 2;
constexpr int MNK_B = 3;
constexpr int MNK_M0 = 4;
constexpr int MNK_N0 = 5;

// Full-load mode selectors (namespace-scope constexpr already has internal linkage, so no `static` is needed).
constexpr uint64_t A_FULL_LOAD_MODE = 1;
constexpr uint64_t B_FULL_LOAD_MODE = 2;

/**
 * Plain aggregate of four signed 64-bit extents describing a matmul problem.
 * NOTE(review): the per-field meanings below are inferred from the names -- confirm against callers.
 */
struct MatmulShape {
    int64_t m; // presumably the M dimension
    int64_t n; // presumably the N dimension
    int64_t k; // presumably the K (reduction) dimension
    int64_t b; // presumably the batch count
};

/**
 * Signed rounded-up division, evaluated as (a + b - 1) / b.
 *
 * For a >= 0 and b > 0 (and no overflow of a + b - 1) this is the mathematical ceiling of a / b.
 * Other sign combinations simply follow C++ truncating division of that expression.
 *
 * @param a dividend
 * @param b divisor; when it is 0 no division is performed
 * @return the rounded-up quotient, or a unchanged when b == 0
 */
__host_aicore__ inline int64_t CeilDiv(int64_t a, int64_t b)
{
    // A zero divisor is tolerated: the dividend is handed back as-is.
    return (b == 0) ? a : (a + b - 1) / b;
}

/**
 * Rounds a up to a multiple of b, evaluated as (a + b - 1) / b * b.
 *
 * For a >= 0 and b > 0 (and no overflow) the result is the smallest multiple of b that is >= a.
 *
 * @param a value to align
 * @param b alignment granularity; when it is 0 no alignment is performed
 * @return the aligned value, or a unchanged when b == 0
 */
__host_aicore__ inline int64_t CeilAlign(int64_t a, int64_t b)
{
    if (b != 0) {
        // Number of b-sized units needed to cover a, then scaled back up.
        const int64_t units = (a + b - 1) / b;
        return units * b;
    }
    return a;
}

/**
 * Returns the larger of two values, compared with operator>.
 * When a > b is false (including equal operands) b is returned.
 */
template <typename T>
__aicore__ inline T Max(T a, T b)
{
    if (a > b) {
        return a;
    }
    return b;
}

/**
 * Returns the smaller of two values, compared with operator>.
 * When a > b is false (including equal operands) a is returned.
 */
template <typename T>
__aicore__ inline T Min(T a, T b)
{
    if (a > b) {
        return b;
    }
    return a;
}

/**
 * Unsigned rounded-up division: the smallest q such that q * b >= a.
 *
 * Computed from the quotient and remainder rather than (a + b - 1) / b, because that sum wraps modulo 2^64
 * for large a and would yield a far too small result (e.g. 0 for a == UINT64_MAX, b == 2).
 *
 * @param a dividend
 * @param b divisor; when it is 0 no division is performed
 * @return ceil(a / b), or a unchanged when b == 0
 */
__aicore__ inline uint64_t CeilDiv(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return a;
    }
    const uint64_t quotient = a / b;
    // Any non-zero remainder needs one extra unit of b.
    return (a % b == 0) ? quotient : quotient + 1;
}

/**
 * Rounds a up to the next multiple of b.
 *
 * Works from the remainder instead of (a + b - 1) / b * b, so the intermediate sum cannot wrap modulo 2^64
 * when the aligned result itself is representable (e.g. a == UINT64_MAX, b == 3 stays UINT64_MAX).
 * If the true aligned value exceeds UINT64_MAX the result still wraps, as with any unsigned arithmetic.
 *
 * @param a value to align
 * @param b alignment granularity; when it is 0 no alignment is performed
 * @return the smallest multiple of b that is >= a, or a unchanged when b == 0
 */
__aicore__ inline uint64_t Align(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return a;
    }
    const uint64_t remainder = a % b;
    // Already aligned values are returned untouched; otherwise pad up to the next multiple.
    return (remainder == 0) ? a : a + (b - remainder);
}

/**
 * Reports the AIC-to-AIV task ratio selected at compile time for the target platform.
 *
 * @return 2 (aic:aiv = 1:2) when __DAV_C310__ is defined or __CCE_AICORE__ equals 220; 1 otherwise
 */
__aicore__ inline uint32_t GetAicAivTaskRation()
{
#if defined(__DAV_C310__) || __CCE_AICORE__ == 220
    constexpr uint32_t aivPerAic = 2U; // 2: aic:aiv = 1:2
#else
    constexpr uint32_t aivPerAic = 1U;
#endif
    return aivPerAic;
}

/**
 * Type trait whose value is true exactly when std::is_empty<T> is true (T is an empty class type).
 * T is inspected as given: no cv-qualifier or reference stripping happens here (is_static_v does that).
 * NOTE(review): presumably used to recognize compile-time constant tag types from integral_constant.h -- confirm.
 */
template <class T>
struct is_static : AscendC::Std::bool_constant<std::is_empty<T>::value> {};

/**
 * Convenience variable template: the value of is_static evaluated on
 * AscendC::Std::remove_cvref_t<T> rather than on T itself.
 */
template <class T>
constexpr bool is_static_v = is_static<AscendC::Std::remove_cvref_t<T>>::value;

/**
 * Primary template: any type that does not match the specialization below yields false.
 */
template <typename Stride>
struct is_2d_nz_c0_32_impl : AscendC::Std::false_type {};

/**
 * Specialization for a tuple of two 2-element tuples, ((T0, T1), (U0, U1)).
 * Yields true only when T0 is _32, T1 is _512, U0 is _1 and U1 is NOT a static (empty) type per is_static_v;
 * every other combination yields false.
 * NOTE(review): judging by the names this identifies the stride of a 2D NZ layout with C0 = 32 -- confirm.
 */
template <typename T0, typename T1, typename U0, typename U1>
struct is_2d_nz_c0_32_impl<AscendC::Std::tuple<AscendC::Std::tuple<T0, T1>, AscendC::Std::tuple<U0, U1>>>
    : AscendC::Std::bool_constant<AscendC::Std::is_same_v<T0, Act::Gemm::_32> &&
                                  AscendC::Std::is_same_v<T1, Act::Gemm::_512> &&
                                  AscendC::Std::is_same_v<U0, Act::Gemm::_1> && !is_static_v<U1>> {};

/**
 * Public entry point for the check: applies AscendC::Std::remove_cvref_t to Stride and
 * inherits the result of is_2d_nz_c0_32_impl for the resulting type.
 */
template <class Stride>
struct is_2d_nz_c0_32 : is_2d_nz_c0_32_impl<typename AscendC::Std::remove_cvref_t<Stride>> {};
} // namespace Gemm
} // namespace Act
#endif